{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "091cde80",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty CUDA_VISIBLE_DEVICES is presumably meant to hide every GPU so the notebook runs on CPU.\n",
    "# This cell is placed before the transformers import cell so the setting is in effect first.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "# NOTE(review): TF_FORCE_GPU_ALLOW_GROWTH looks like a TensorFlow setting; no TensorFlow import is\n",
    "# visible in this notebook, so it is probably a leftover - confirm before removing.\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Model, T5Tokenizer\n",
    "from transformers.models.bart.modeling_bart import shift_tokens_right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at mesolitica/t5-super-tiny-bahasa-cased were not used when initializing T5Model: ['lm_head.weight']\n",
      "- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d84cd3986ef4949a9da7846f95827e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading spiece.model:   0%|          | 0.00/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "939fdf1c1a6f4d03ae519446ae27b209",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading special_tokens_map.json:   0%|          | 0.00/2.15k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a56e7af1c444045a21f415536c0fb88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer_config.json:   0%|          | 0.00/2.35k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Model weights and tokenizer are loaded from the same checkpoint name.\n",
    "CHECKPOINT = \"mesolitica/t5-super-tiny-bahasa-cased\"\n",
    "\n",
    "model = T5Model.from_pretrained(CHECKPOINT)\n",
    "tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_text = 'terjemah Inggeris ke Melayu: '\n",
    "strings = ['i like', 'i hate so much']\n",
    "\n",
    "# Tokenize each prompt-prefixed sentence on its own; [0] keeps the first row of the returned tensor.\n",
    "input_ids = [\n",
    "    {'input_ids': tokenizer.encode(f'{initial_text}{s}', return_tensors='pt')[0]}\n",
    "    for s in strings\n",
    "]\n",
    "\n",
    "# Pad every example to the longest sequence in the batch.\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "\n",
    "# Decoder inputs are derived from the padded input ids via shift_tokens_right.\n",
    "pad_id = model.config.pad_token_id\n",
    "start_id = model.config.decoder_start_token_id\n",
    "padded['decoder_input_ids'] = shift_tokens_right(padded['input_ids'], pad_id, start_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12, 256])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward pass over the padded batch built above, including the shifted decoder_input_ids.\n",
    "outputs = model(**padded)\n",
    "# First element of the model output; the recorded run shows shape (2, 12, 256).\n",
    "# NOTE(review): for T5Model this is presumably the decoder's last hidden state - confirm.\n",
    "hidden_states = outputs[0]\n",
    "hidden_states.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise comparison of the padded input ids with the configured EOS id.\n",
    "# The result has the same shape as input_ids ((2, 12) in the recorded run).\n",
    "eos_mask = padded['input_ids'].eq(model.config.eos_token_id)\n",
    "eos_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 256])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The pooling below flattens all EOS positions and regroups them with view(batch, -1, dim).\n",
    "# That is only valid when every row holds the same, non-zero number of EOS tokens; otherwise the\n",
    "# regrouping either fails with an obscure reshape error or silently mixes vectors across rows.\n",
    "eos_counts = eos_mask.sum(dim=1)\n",
    "if eos_counts.numel() == 0 or eos_counts[0].item() == 0 or not eos_counts.eq(eos_counts[0]).all().item():\n",
    "    raise ValueError('All examples must have the same, non-zero number of <eos> tokens.')\n",
    "\n",
    "# Select the EOS positions, regroup them per example and keep the last EOS vector of each row.\n",
    "sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[\n",
    "    :, -1, :\n",
    "]\n",
    "sentence_representation.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "\n",
    "class BartClassificationHead(nn.Module):\n",
    "    \"\"\"Two-layer classification head applied to sentence-level features.\n",
    "\n",
    "    Pipeline: dropout -> linear -> tanh -> dropout -> linear.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the last dimension of the incoming features.\n",
    "        inner_dim: width of the hidden projection.\n",
    "        num_classes: number of scores produced per input vector.\n",
    "        pooler_dropout: drop probability used before and after the hidden projection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(input_dim, inner_dim)\n",
    "        self.dropout = nn.Dropout(p=pooler_dropout)\n",
    "        self.out_proj = nn.Linear(inner_dim, num_classes)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Return class scores; only the last dimension of the input changes.\"\"\"\n",
    "        projected = self.dense(self.dropout(hidden_states))\n",
    "        activated = self.dropout(torch.tanh(projected))\n",
    "        return self.out_proj(activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Head with hidden width d_model, 3 output classes and 0.1 dropout.\n",
    "classification_head = BartClassificationHead(\n",
    "    model.config.d_model,\n",
    "    model.config.d_model,\n",
    "    3,\n",
    "    0.1,\n",
    ")\n",
    "\n",
    "# `model` is a T5Model: its `_init_weights` dispatches on T5-specific module types and skips bare\n",
    "# nn.Linear layers, so calling it on the head's layers (as the BART recipe does) changes nothing.\n",
    "# Initialise explicitly instead: normal weights scaled by initializer_factor * d_model ** -0.5,\n",
    "# and zero biases.\n",
    "init_std = model.config.initializer_factor * model.config.d_model ** -0.5\n",
    "for linear in (classification_head.dense, classification_head.out_proj):\n",
    "    nn.init.normal_(linear.weight, mean=0.0, std=init_std)\n",
    "    nn.init.zeros_(linear.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12, 3])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Head applied to the full hidden states: one score vector per position ((2, 12, 3) in the recorded run).\n",
    "classification_head(hidden_states).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Head applied to the pooled EOS vectors: one score vector per example ((2, 3) in the recorded run).\n",
    "classification_head(sentence_representation).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
